import time
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import numba
from scipy.sparse import csr_matrix, coo_matrix

from itertools import chain
from typing import Optional

import torch
from torch import Tensor

from torch_geometric import EdgeIndex
from torch_geometric.utils.num_nodes import maybe_num_nodes

def load_data(dataset, dataset_folder="/data/shared/zhexu/"):
    """Load one of the supported datasets through its raw-text loader.

    The loader module is imported lazily, so only the code for the
    requested dataset has to be importable.

    Args:
        dataset: Dataset name; one of ``'cora'``, ``'pubmed'``,
            ``'ogbn-arxiv'`` or ``'ogbn-products'``.
        dataset_folder: Folder forwarded to the loader as
            ``dataset_folder``.

    Returns:
        Whatever the selected ``get_raw_text`` returns. The script in this
        file unpacks it as ``(data, text, label2text)``.

    Raises:
        SystemExit: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'cora':
        from raw_data_utils.load_cora import get_raw_text
    elif dataset == 'pubmed':
        from raw_data_utils.load_pubmed import get_raw_text
    elif dataset == 'ogbn-arxiv':
        from raw_data_utils.load_arxiv import get_raw_text
    elif dataset == 'ogbn-products':
        from raw_data_utils.load_products import get_raw_text
    else:
        # Raise SystemExit directly instead of calling the ``exit`` helper:
        # ``exit`` is injected by the ``site`` module for interactive use and
        # is missing under ``python -S`` or in frozen/embedded interpreters.
        # The exception type and message are the same as before.
        raise SystemExit(f'Error: Dataset {dataset} not supported')

    return get_raw_text(dataset_folder=dataset_folder)
    
def _get_ppr(  # pragma: no cover
    rowptr: np.ndarray,
    col: np.ndarray,
    alpha: float,
    eps: float,
    target: Optional[np.ndarray] = None,
):
    """Approximate personalized PageRank via local residual pushes.

    For every source node the function keeps an estimate dict ``p`` and a
    residual dict ``r``. The source starts with residual ``alpha``. Nodes are
    popped from a work list; the popped node's whole residual is added to
    its estimate and ``(1 - alpha) * residual / degree`` is added to the
    residual of each of its neighbours. A neighbour is (re-)queued when its
    residual reaches ``alpha * eps * degree``.

    This function is not called directly in this file: ``get_ppr`` compiles
    it with ``numba.jit(nopython=True)``, so the body has to stay within
    what numba can type in nopython mode.

    Args:
        rowptr: CSR row pointer; ``rowptr[u]:rowptr[u + 1]`` delimits the
            neighbours of node ``u`` inside ``col``.
        col: CSR column (neighbour) indices.
        alpha: Residual initially placed on each source node; a fraction
            ``1 - alpha`` of every pushed residual is spread to neighbours.
            (Presumably the PPR teleport/restart probability.)
        eps: Tolerance used in the queueing threshold
            ``alpha * eps * degree``.
        target: Optional array of source node ids. If ``None`` every node
            ``0 .. len(rowptr) - 2`` is used as a source.

    Returns:
        A pair ``(js, vals)`` of lists with one entry per source node:
        ``js[i]`` holds the node ids that received an estimate and
        ``vals[i]`` the matching estimate values, in the same order.
    """

    # One output slot per source node.
    num_nodes = len(rowptr) - 1 if target is None else len(target)
    # Loop-invariant part of the queueing threshold.
    alpha_eps = alpha * eps
    # Placeholder entries, overwritten per source below. The int / float
    # seeds presumably exist so numba can infer the list element types --
    # TODO confirm.
    js = [[0]] * num_nodes
    vals = [[0.]] * num_nodes

    # NOTE(review): numba.prange should behave like a plain range unless the
    # function is compiled with parallel=True (get_ppr currently does not
    # pass it) -- confirm against the numba docs.
    for inode_uint in numba.prange(num_nodes):
        # Resolve the actual source node id for this output slot.
        if target is None:
            inode = numba.int64(inode_uint)
        else:
            inode = target[inode_uint]

        # p: estimates, r: residuals, q: work list of nodes still to push.
        p = {inode: 0.0}
        r = {}
        r[inode] = alpha
        q = [inode]

        while len(q) > 0:
            # list.pop() takes the most recently appended node.
            unode = q.pop()

            # Move the node's entire residual into its estimate ...
            res = r[unode] if unode in r else 0
            if unode in p:
                p[unode] += res
            else:
                p[unode] = res

            # ... and clear it before redistributing to the neighbours.
            r[unode] = 0
            start, end = rowptr[unode], rowptr[unode + 1]
            ucount = end - start

            # A node without neighbours skips this loop, so the division by
            # ucount is never evaluated with ucount == 0.
            for vnode in col[start:end]:
                # Equal share of the non-retained residual per neighbour.
                _val = (1 - alpha) * res / ucount
                if vnode in r:
                    r[vnode] += _val
                else:
                    r[vnode] = _val

                # Queue the neighbour once its residual reaches the
                # degree-scaled threshold; skip it if it is already queued
                # (membership test on the list q).
                res_vnode = r[vnode] if vnode in r else 0
                vcount = rowptr[vnode + 1] - rowptr[vnode]
                if res_vnode >= alpha_eps * vcount:
                    if vnode not in q:
                        q.append(vnode)

        # Store this source's results; keys and values of the same dict
        # are emitted in matching order.
        js[inode_uint] = list(p.keys())
        vals[inode_uint] = list(p.values())

    return js, vals

def get_ppr(
    edge_index: Tensor,
    alpha: float = 0.2,
    eps: float = 1e-5,
    target: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
):
    """Compute approximate personalized PageRank scores as a sparse graph.

    The edges are converted to CSR form, handed to the numba-compiled
    ``_get_ppr`` kernel on the CPU, and the per-source results are flattened
    back into an edge list with one weight per edge.

    Args:
        edge_index: Edge indices of the input graph (assumed to be the usual
            ``[2, num_edges]`` layout accepted by ``EdgeIndex`` -- confirm
            against callers).
        alpha: Forwarded to ``_get_ppr`` as ``alpha``.
        eps: Forwarded to ``_get_ppr`` as ``eps``.
        target: Optional tensor of source node ids. If ``None``, every node
            ``0 .. num_nodes - 1`` is used as a source.
        num_nodes: Number of nodes; resolved with ``maybe_num_nodes`` when
            ``None``.

    Returns:
        A pair ``(edge_index, weight)``: ``edge_index`` stacks the source
        node of each result (row 0) on top of the node that received a score
        (row 1), and ``weight`` holds the matching scores. Both live on the
        device of the input ``edge_index``.
    """
    # cache=True lets numba reuse a previous on-disk compilation of the kernel.
    # NOTE: numba.jit(nopython=True, parallel=True) is an alternative
    # compilation mode for the same kernel; it is not enabled here.
    _get_ppr_numba = numba.jit(nopython=True, cache=True)(_get_ppr)

    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    edge_index = EdgeIndex(edge_index, sparse_size=(num_nodes, num_nodes))
    # Sort by row before extracting the CSR (rowptr, col) representation.
    edge_index = edge_index.sort_by('row')[0]
    (rowptr, col), _ = edge_index.get_csr()

    # The kernel works on NumPy arrays, so everything is moved to the CPU.
    cols, weights = _get_ppr_numba(
        rowptr.cpu().numpy(),
        col.cpu().numpy(),
        alpha,
        eps,
        None if target is None else target.cpu().numpy(),
    )

    device = edge_index.device
    col = torch.tensor(list(chain.from_iterable(cols)), device=device)
    weight = torch.tensor(list(chain.from_iterable(weights)), device=device)
    # Number of result entries per source node.
    deg = torch.tensor([len(value) for value in cols], device=device)

    # Build the source ids on the same device as `deg` and `col`; a CPU-only
    # `row` would make repeat_interleave / stack fail for non-CPU inputs.
    if target is None:
        row = torch.arange(num_nodes, device=device)
    else:
        row = target.to(device)
    # Repeat every source id once per result entry it produced.
    row = row.repeat_interleave(deg, output_size=col.numel())

    edge_index = torch.stack([row, col], dim=0)

    return edge_index, weight

# Script driver: for every configured dataset, compute approximate PPR
# scores and write one line per node in the form
# "<node>\t<neighbour ids separated by spaces, highest score first>".
import os

dataset_folder = "../raw_data/"
output_folder = "../processed_data/"
datasets = ['cora']
eps = 1e-5
for dataset in datasets:
    output_path = f'{output_folder}{dataset}/{dataset}_ppradj_list.txt'
    # Create the target directory up front so a missing folder cannot fail
    # the run only after the expensive PPR computation has finished.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Load and compute BEFORE opening the output file: opening with 'w'
    # truncates immediately, so a failure here would otherwise replace an
    # existing result with an empty or partial file.
    data, text, label2text = load_data(dataset, dataset_folder)

    num_nodes = data.x.shape[0]
    edge_index = data.edge_index.contiguous()

    print(dataset)
    print(edge_index.shape)
    print(eps)
    start = time.time()
    edge_list, edge_weights = get_ppr(edge_index, alpha=0.1, eps=eps)
    print(f'Time passed: {time.time()-start}')
    print(edge_list.shape)
    print(edge_weights.shape)

    # get_ppr returns tensors on the device of its input; .cpu() is a no-op
    # for CPU tensors and makes the NumPy conversion safe otherwise.
    edge_list = edge_list.transpose(1, 0).cpu().numpy()
    edge_weights = edge_weights.cpu().numpy()

    # Every node gets an entry, so nodes without results still produce a line.
    adjacency_list = {i: {} for i in range(num_nodes)}
    for (node1, node2), edge_weight in zip(edge_list.tolist(),
                                           edge_weights.tolist()):
        adjacency_list[node1][node2] = edge_weight

    with open(output_path, 'w') as fout:
        for i, scored_neighbors in adjacency_list.items():
            # Stable sort, highest score first; ties keep insertion order.
            ranked = sorted(scored_neighbors.items(),
                            key=lambda x: x[1], reverse=True)
            neighbors = [str(x[0]) for x in ranked]
            fout.write(f'{i}\t')
            fout.write(' '.join(neighbors)+'\n')
